{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Next-frame prediction with Conv-LSTM\n",
    "\n",
    "**Author:** [jeammimi](https://github.com/jeammimi)<br>\n",
    "**Date created:** 2016/11/02<br>\n",
    "**Last modified:** 2020/05/01<br>\n",
    "**Description:** Predict the next frame in a sequence using a Conv-LSTM model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "This script demonstrates the use of a convolutional LSTM model.\n",
    "The model is used to predict the next frame of an artificially\n",
    "generated movie which contains moving squares.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import numpy as np\n",
    "import pylab as plt\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Build a model\n",
    "\n",
    "We create a model which takes as input movies of shape\n",
    "`(n_frames, width, height, channels)` and returns a movie\n",
    "of identical shape.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Every recurrent stage is the same ConvLSTM2D + BatchNormalization pair,\n",
    "# so the stack is assembled in a loop instead of being spelled out four times.\n",
    "model_layers = [\n",
    "    keras.Input(shape=(None, 40, 40, 1))  # Variable-length sequence of 40x40x1 frames\n",
    "]\n",
    "for stage in range(4):\n",
    "    model_layers.append(\n",
    "        layers.ConvLSTM2D(\n",
    "            filters=40, kernel_size=(3, 3), padding=\"same\", return_sequences=True\n",
    "        )\n",
    "    )\n",
    "    model_layers.append(layers.BatchNormalization())\n",
    "\n",
    "# A final 3D convolution reduces the 40 feature maps to a single channel;\n",
    "# the sigmoid keeps every output pixel between 0 and 1.\n",
    "model_layers.append(\n",
    "    layers.Conv3D(\n",
    "        filters=1, kernel_size=(3, 3, 3), activation=\"sigmoid\", padding=\"same\"\n",
    "    )\n",
    ")\n",
    "\n",
    "seq = keras.Sequential(model_layers)\n",
    "seq.compile(loss=\"binary_crossentropy\", optimizer=\"adadelta\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Generate artificial data\n",
    "\n",
    "Generate movies with 3 to 7 moving squares inside.\n",
    "The squares are 4x4 or 6x6 pixels in size,\n",
    "and move linearly over time.\n",
    "For convenience, we first create movies with bigger width and height (80x80)\n",
    "and at the end we select a 40x40 window.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "def generate_movies(n_samples=1200, n_frames=15):\n",
    "    \"\"\"Generate synthetic movies of moving squares and their next-frame targets.\n",
    "\n",
    "    Each movie holds 3 to 7 squares, each moving by -1, 0 or +1 pixel per\n",
    "    frame along each axis. Frames are drawn on an 80x80 canvas and then\n",
    "    cropped to the central 40x40 window.\n",
    "\n",
    "    Args:\n",
    "        n_samples: Number of movies to generate.\n",
    "        n_frames: Number of frames in each movie.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(noisy_movies, shifted_movies)` of float64 arrays, each of\n",
    "        shape `(n_samples, n_frames, 40, 40, 1)`. `noisy_movies` holds the\n",
    "        input frames (squares plus random +/-0.1 noise, capped at 1), and\n",
    "        `shifted_movies[:, t]` holds the noise-free frame for time `t + 1`.\n",
    "    \"\"\"\n",
    "    row = 80\n",
    "    col = 80\n",
    "    # `np.float` was removed in NumPy 1.24; `np.float64` is the same dtype.\n",
    "    noisy_movies = np.zeros((n_samples, n_frames, row, col, 1), dtype=np.float64)\n",
    "    shifted_movies = np.zeros((n_samples, n_frames, row, col, 1), dtype=np.float64)\n",
    "\n",
    "    for i in range(n_samples):\n",
    "        # Add 3 to 7 moving squares\n",
    "        n = np.random.randint(3, 8)\n",
    "\n",
    "        for j in range(n):\n",
    "            # Initial position\n",
    "            xstart = np.random.randint(20, 60)\n",
    "            ystart = np.random.randint(20, 60)\n",
    "            # Direction of motion: -1, 0 or +1 pixel per frame on each axis\n",
    "            directionx = np.random.randint(0, 3) - 1\n",
    "            directiony = np.random.randint(0, 3) - 1\n",
    "\n",
    "            # Half-width of the square (2 or 3): the drawn square spans\n",
    "            # 2 * w pixels per side, i.e. 4x4 or 6x6.\n",
    "            w = np.random.randint(2, 4)\n",
    "\n",
    "            for t in range(n_frames):\n",
    "                x_shift = xstart + directionx * t\n",
    "                y_shift = ystart + directiony * t\n",
    "                noisy_movies[\n",
    "                    i, t, x_shift - w : x_shift + w, y_shift - w : y_shift + w, 0\n",
    "                ] += 1\n",
    "\n",
    "                # Make it more robust by adding noise.\n",
    "                # The idea is that if during inference,\n",
    "                # the value of the pixel is not exactly one,\n",
    "                # we need to train the model to be robust and still\n",
    "                # consider it as a pixel belonging to a square.\n",
    "                # With probability 1/2, add +0.1 or -0.1 over a patch that\n",
    "                # extends one pixel beyond the square on every side.\n",
    "                if np.random.randint(0, 2):\n",
    "                    noise_f = (-1) ** np.random.randint(0, 2)\n",
    "                    noisy_movies[\n",
    "                        i,\n",
    "                        t,\n",
    "                        x_shift - w - 1 : x_shift + w + 1,\n",
    "                        y_shift - w - 1 : y_shift + w + 1,\n",
    "                        0,\n",
    "                    ] += (noise_f * 0.1)\n",
    "\n",
    "                # Shift the ground truth by 1\n",
    "                x_shift = xstart + directionx * (t + 1)\n",
    "                y_shift = ystart + directiony * (t + 1)\n",
    "                shifted_movies[\n",
    "                    i, t, x_shift - w : x_shift + w, y_shift - w : y_shift + w, 0\n",
    "                ] += 1\n",
    "\n",
    "    # Cut to a 40x40 window\n",
    "    noisy_movies = noisy_movies[::, ::, 20:60, 20:60, ::]\n",
    "    shifted_movies = shifted_movies[::, ::, 20:60, 20:60, ::]\n",
    "    # Overlapping squares add up to more than 1; cap pixel values at 1.\n",
    "    noisy_movies[noisy_movies >= 1] = 1\n",
    "    shifted_movies[shifted_movies >= 1] = 1\n",
    "    return noisy_movies, shifted_movies\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Train the model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "epochs = 200\n",
    "\n",
    "# 1200 movies are generated, but only the first 1000 are passed to `fit`;\n",
    "# the remaining 200 are left out of training.\n",
    "noisy_movies, shifted_movies = generate_movies(n_samples=1200)\n",
    "\n",
    "# `validation_split=0.1` makes Keras hold out 10% of the arrays passed here\n",
    "# for validation.\n",
    "seq.fit(\n",
    "    x=noisy_movies[:1000],\n",
    "    y=shifted_movies[:1000],\n",
    "    epochs=epochs,\n",
    "    batch_size=10,\n",
    "    validation_split=0.1,\n",
    "    verbose=2,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Test the model on one movie\n",
    "\n",
    "Feed it with the first 7 positions and then\n",
    "predict the new positions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "movie_index = 1004\n",
    "track = noisy_movies[movie_index][:7]\n",
    "\n",
    "# Autoregressive rollout: predict on everything seen so far, keep only the\n",
    "# last predicted frame, and append it to the track before the next step.\n",
    "for step in range(16):\n",
    "    predicted = seq.predict(np.expand_dims(track, axis=0))\n",
    "    next_frame = predicted[:, -1]\n",
    "    track = np.concatenate([track, next_frame], axis=0)\n",
    "\n",
    "\n",
    "# And then compare the predictions\n",
    "# to the ground truth\n",
    "track2 = noisy_movies[movie_index]\n",
    "for i in range(15):\n",
    "    fig = plt.figure(figsize=(10, 5))\n",
    "\n",
    "    # Left panel: the 7 seed frames first, then the model's own predictions.\n",
    "    ax = fig.add_subplot(121)\n",
    "    if i < 7:\n",
    "        ax.text(1, 3, \"Initial trajectory\", fontsize=20)\n",
    "    else:\n",
    "        ax.text(1, 3, \"Predictions !\", fontsize=20, color=\"w\")\n",
    "    ax.imshow(track[i, :, :, 0])\n",
    "\n",
    "    # Right panel: the reference frame for the same time step. From frame 2\n",
    "    # onwards it is read from `shifted_movies` at index i - 1; the first two\n",
    "    # frames come from the noisy input movie.\n",
    "    ax = fig.add_subplot(122)\n",
    "    ax.text(1, 3, \"Ground truth\", fontsize=20)\n",
    "    if i >= 2:\n",
    "        truth = shifted_movies[movie_index][i - 1, :, :, 0]\n",
    "    else:\n",
    "        truth = track2[i, :, :, 0]\n",
    "    ax.imshow(truth)\n",
    "\n",
    "    fig.savefig(\"%i_animate.png\" % (i + 1))\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "conv_lstm",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}